# -*- coding: utf-8 -*-
"""
Created on Fri Aug 30 16:09:09 2019

@author: varys
"""

import scipy.io as sio
# 0-based index into the two per-subject lists below (1 -> second entry).
subject_no = 1
# MAT files, one per subject; the one at `subject_no` is loaded below.
subject_names = [
    'wangkui_20140620.mat',
    'xiayulu_20140527.mat',
    'liuqiujun_20140621.mat',
    'sunxiangyu_20140511.mat',
    'dujingcheng_20131027.mat',
]
# Short codes parallel to `subject_names`; used as the prefix of the
# per-trial variable names inside the MAT file ('<code>_eeg<trial>').
abbreviation = 'wk xyl lqj sxy djc'.split()
# Load the selected subject's MAT file; loadmat returns a dict mapping
# MATLAB variable names to arrays.
eegdata = sio.loadmat(file_name=subject_names[subject_no])

# 62 EEG channel labels. They are used as the DataFrame row index in
# plot_trail_topo, so they are assumed to match the row order of each trial's
# data matrix (TODO confirm against the dataset's channel-order file).
channels_name = (
    'FP1 FPZ FP2 '
    'AF3 AF4 '
    'F7 F5 F3 F1 FZ F2 F4 F6 F8 '
    'FT7 FC5 FC3 FC1 FCZ FC2 FC4 FC6 FT8 '
    'T7 C5 C3 C1 CZ C2 C4 C6 T8 '
    'TP7 CP5 CP3 CP1 CPZ CP2 CP4 CP6 TP8 '
    'P7 P5 P3 P1 PZ P2 P4 P6 P8 '
    'PO7 PO5 PO3 POZ PO4 PO6 PO8 '
    'CB1 O1 OZ O2 CB2'
).split()
import pandas as pd
import matplotlib.pyplot as plt
# Label of each of the 15 trials; trial n (1-based) uses labels[n - 1].
labels = [
    1, 0, -1,
    -1, 0, 1,
    -1, 0, 1,
    1, 0, -1,
    0, 1, -1,
]
# Human-readable name for each label value (appended to output file names).
label_meanings = {
    1: "positive",
    0: "neutral",
    -1: "negative",
}
# Extrapolation mode names (not used by the visible code).
extrapolations = 'box head local'.split()
# How many trials to plot, starting from trial 1.
trail_num = 3

def plot_trail_topo(trail_no):
    """
    Plot one trial's power spectral density (PSD) and save it as SVG.

    Despite the name, this draws a PSD plot, not a topographic map.

    Parameters
    ----------
    trail_no : int
        1-based trial number. It selects the MAT-file variable
        ``<abbreviation[subject_no]>_eeg<trail_no>`` from ``eegdata`` and
        the label ``labels[trail_no - 1]``.

    Side effects
    ------------
    Writes ``./spectral_trail<trail_no><label name>.svg`` to the working
    directory. The figure is left open so a later ``plt.show()`` can show it.
    """
    import mne

    # One row per channel, labelled with the channel names. RawArray expects
    # data shaped (n_channels, n_times).
    trail_data = pd.DataFrame(
        eegdata[abbreviation[subject_no] + '_eeg' + str(trail_no)],
        index=channels_name,
    )
    # Optional channel subset:
    # trail_data = trail_data.loc[['P1', 'PO6'], :]

    sfreq = 200  # sampling frequency (Hz)
    ch_names = trail_data.index.tolist()
    # Derive channel types from the rows actually present, so names and types
    # always have the same length (also when the subset above is enabled).
    ch_types = ['eeg'] * len(ch_names)

    # NOTE(review): `mne.channels.read_montage` and the `montage=` argument of
    # `mne.create_info` belong to the older (0.18-era) MNE API and were removed
    # in later releases. Confirm the installed MNE version before upgrading.
    seed_montage = mne.channels.read_montage('seed')
    eeg_info = mne.create_info(
        ch_names=ch_names,
        sfreq=sfreq,
        ch_types=ch_types,
        montage=seed_montage,
    )

    # Pass a plain ndarray rather than relying on implicit DataFrame conversion.
    raw_arr = mne.io.RawArray(trail_data.values, eeg_info)

    # Raw-signal line plot (disabled):
    # raw_arr.plot(scalings='auto', title=label_meanings[labels[trail_no-1]], show=True, block=False).savefig('./raw_data_trail'+str(trail_no)+label_meanings[labels[trail_no-1]]+'.svg')

    fig = raw_arr.plot_psd(fmax=75)
    label_name = label_meanings[labels[trail_no - 1]]
    fig.savefig('./spectral_trail' + str(trail_no) + label_name + '.svg')

# Plot and save trials 1..trail_num, then show every figure that was created.
for trail_no in range(1, trail_num + 1):
    plot_trail_topo(trail_no)

plt.show()


